
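# Single-layer LSTM implemented only with standard PyTorch layers.

# Reference: https://blog.csdn.net/luoganttcc/article/details/105959310
# The referenced implementation stores its weights directly on the module as raw
# Parameters; those weights are lost when the model is converted to Caffe, so
# that approach does not work here.

# Instead, the weights are kept in nn.Linear layers, which perform the matrix
# products with the input and the hidden state. This approach works.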

import os
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn import init
from torch import Tensor
import math

'''
# 参考: https://blog.csdn.net/luoganttcc/article/details/105959310
# 此参考连接中直接在网络中保存权重的方式，在转caffe时会丢失权重，故此方法不行
class NaiveLSTM(nn.Module):
    """Naive LSTM like nn.LSTM"""
    def __init__(self, input_size: int, hidden_size: int):
        super(NaiveLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # PyTorch官方LSTM中：
        #   w_ii、w_if、w_io、w_ig为同一个权重，b_ii、b_if、b_io、b_ig为同一个偏置
        #   w_hi、w_hf、w_ho、w_hg为同一个权重，b_hi、b_hf、b_ho、b_hg为同一个偏置
        
        # self.w_ii = Parameter(Tensor(hidden_size, input_size))
        # self.w_hi = Parameter(Tensor(hidden_size, hidden_size))
        # self.b_ii = Parameter(Tensor(hidden_size, 1))
        # self.b_hi = Parameter(Tensor(hidden_size, 1))
        self.w_ii = Parameter(Tensor(input_size, input_size))
        self.w_hi = Parameter(Tensor(input_size, hidden_size))
        self.b_ii = Parameter(Tensor(input_size, 1))
        self.b_hi = Parameter(Tensor(input_size, 1))

        self.reset_weigths()

    def reset_weigths(self):
        """reset weights
        """
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    # def forward(self, inputs: Tensor, state: Tuple[Tensor]) \
    #     -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    def forward(self, inputs: Tensor, state):
        """Forward
        Args:
            inputs: [1, 1, input_size]
            state: ([1, 1, hidden_size], [1, 1, hidden_size])
        """
        # !!!单元数为1，只处理一个时间节点

        if state is None: # 输入无初始状态，默认为0
            h_t = torch.zeros(1, self.hidden_size).t()
            c_t = torch.zeros(1, self.input_size).t()
        else:
            (h, c) = state
            h_t = h.squeeze(0).t()
            c_t = c.squeeze(0).t()

        # hidden_seq = []
        x = inputs[:, 0, :].t()
        temp1 = self.w_ii @ x + self.b_ii       # PyTorch中 @-矩阵相乘
        temp2 = self.w_hi @ h_t + self.b_hi

        i = torch.sigmoid(temp1 + temp2) # i=f=o
        g = torch.tanh(temp1 + temp2)
        
        c_next = i * c_t + i * g        # PyTorch中 *-矩阵点乘
        h_next = i * torch.tanh(c_next)
        c_next_t = c_next.t().unsqueeze(0)
        h_next_t = h_next.t().unsqueeze(0)
        # hidden_seq.append(h_next_t)
        hidden_seq = h_next_t

        # hidden_seq = torch.cat(hidden_seq, dim=0)
        return hidden_seq, (h_next_t, c_next_t)
'''


class NaiveLSTM(nn.Module):
    """Single-step LSTM cell built from two ``nn.Linear`` layers, like ``nn.LSTM``.

    The weights are held by ``nn.Linear`` modules rather than raw
    ``Parameter`` tensors. ``linear1`` corresponds to ``nn.LSTM``'s
    input-to-hidden weight/bias (``weight_ih_l0`` / ``bias_ih_l0``) and
    ``linear2`` to its hidden-to-hidden weight/bias (``weight_hh_l0`` /
    ``bias_hh_l0``). Each layer packs the four gates along its output
    dimension in ``nn.LSTM`` order: input, forget, cell (candidate), output.

    Only one time step of a single sequence is processed per call.
    """
    def __init__(self, input_size: int, hidden_size: int):
        """Create the two packed gate projections.

        Args:
            input_size: number of features in the input vector.
            hidden_size: number of features in the hidden and cell states.
        """
        super(NaiveLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # nn.Linear computes x @ W.T + b, so the input is passed as a row
        # vector and the weight has the same (4*hidden, in) layout nn.LSTM uses.
        self.linear1 = nn.Linear(input_size, 4 * hidden_size)   # W_ih, b_ih
        self.linear2 = nn.Linear(hidden_size, 4 * hidden_size)  # W_hh, b_hh

    def forward(self, inputs: Tensor, state):
        """Run one LSTM time step.

        Args:
            inputs: tensor of shape [1, 1, input_size].
            state: ``None`` (zero initial state) or a tuple ``(h, c)``. Each
                tensor is flattened to ``[1, hidden_size]``, so both
                ``[1, 1, hidden_size]`` and the ``[1, hidden_size, 1]`` layout
                returned by this method are accepted.

        Returns:
            ``(output, (h_next, c_next))``. ``output`` is the same tensor as
            ``h_next``. Every returned tensor has shape
            ``[1, hidden_size, 1]``.
            NOTE(review): ``nn.LSTM`` returns ``[1, 1, hidden_size]``; the
            transposed layout is kept here for backward compatibility.
        """
        if state is None:
            # No initial state given: start from zeros that match the input's
            # dtype and device.
            h_t = torch.zeros(1, self.hidden_size,
                              dtype=inputs.dtype, device=inputs.device)
            c_t = torch.zeros(1, self.hidden_size,
                              dtype=inputs.dtype, device=inputs.device)
        else:
            (h, c) = state
            h_t = h.squeeze(0).reshape(1, -1)
            c_t = c.squeeze(0).reshape(1, -1)

        # Row vector [1, input_size] for the single time step.
        x = inputs[:, 0, :]

        # Packed pre-activations of all four gates: [1, 4 * hidden_size].
        gates = self.linear1(x) + self.linear2(h_t)
        # Each gate owns its own hidden_size-wide slice; they must not be
        # summed together or shared between gates.
        i_pre, f_pre, g_pre, o_pre = torch.split(gates, self.hidden_size, dim=1)

        # Functional activations are used instead of nn.Sigmoid / nn.Tanh
        # modules because the module forms failed during Caffe conversion, see
        # https://github.com/xxradon/PytorchToCaffe/issues/64
        i = nn.functional.sigmoid(i_pre)  # input gate
        f = nn.functional.sigmoid(f_pre)  # forget gate
        g = nn.functional.tanh(g_pre)     # candidate cell state
        o = nn.functional.sigmoid(o_pre)  # output gate

        # ``*`` is the element-wise product.
        c_next = f * c_t + i * g
        h_next = o * nn.functional.tanh(c_next)

        # [1, hidden_size] -> [hidden_size, 1] -> [1, hidden_size, 1]
        c_next_t = c_next.t().unsqueeze(0)
        h_next_t = h_next.t().unsqueeze(0)
        hidden_seq = h_next_t

        return hidden_seq, (h_next_t, c_next_t)



### test 
if __name__ == '__main__':
    # Smoke-test NaiveLSTM, then copy trained LSTM weights from a checkpoint
    # into its linear layers and save the resulting state dict.
    input_size = 2048  # LSTM input feature size
    hidden_size = 512  # hidden state size

    inputs = torch.ones(1, 1, input_size)
    h0 = torch.ones(1, 1, hidden_size)
    c0 = torch.ones(1, 1, hidden_size)
    print("h0.shape: ", h0.shape)
    print("c0.shape: ", c0.shape)
    print("inputs.shape: ", inputs.shape)

    # Build the network.
    naive_lstm = NaiveLSTM(input_size=input_size, hidden_size=hidden_size)

    # Run one forward pass to check that the network executes.
    output1, (hn1, cn1) = naive_lstm(inputs, (h0, c0))
    print(hn1.shape, cn1.shape, output1.shape)

    # Load the weights to be converted.
    weight_path = r"D:\\Works\\Hisi_codes\\SLR-master\\res50lstm_models\\res50lstm_epoch026-lstm.pth"
    print("Load params from : ", weight_path)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # machine and keeps the loaded tensors on the same device as naive_lstm.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model = torch.load(weight_path, map_location="cpu")

    # Names of the tensors stored in the checkpoint, in stored order.
    weights_keys = list(model.keys())

    # Assign the checkpoint tensors to the linear layers by position.
    # NOTE(review): presumably the checkpoint stores the input-hidden weight,
    # hidden-hidden weight, input-hidden bias and hidden-hidden bias in that
    # order (nn.LSTM's parameter order) -- verify against the checkpoint.
    naive_lstm.linear1.weight = Parameter(model[weights_keys[0]])
    naive_lstm.linear2.weight = Parameter(model[weights_keys[1]])
    naive_lstm.linear1.bias = Parameter(model[weights_keys[2]])
    naive_lstm.linear2.bias = Parameter(model[weights_keys[3]])

    # Show all parameters of the network.
    print("naive_lstm.state_dict():\n", naive_lstm.state_dict())

    # Save the converted weights.
    print("save....")
    model_path = r"D:\\Works\\Hisi_codes\\SLR-master\\res50lstm_models"
    torch.save(naive_lstm.state_dict(), os.path.join(model_path, "lstm_with_linear.pth"))

